{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install --upgrade transformers -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['USE_TORCH'] = 'True'  # To use transformers library in TPU\n",
    "os.environ['XLA_USE_BF16'] = 'True'\n",
    "os.environ['PJRT_DEVICE'] = 'TPU'\n",
    "os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'\n",
    "os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'\n",
    "!export HF_HUB_CACHE=\"/mnt/persistent-disk/hf/\"\n",
    "!export HF_HOME=\"/mnt/persistent-disk/hf/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 13.608021,
     "end_time": "2023-11-04T12:34:21.013846",
     "exception": false,
     "start_time": "2023-11-04T12:34:07.405825",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import contextlib\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import torch.distributed as dist\n",
    "import torch_xla.core.xla_model as xm\n",
    "import torch_xla.runtime as xr\n",
    "xr.use_spmd()\n",
    "\n",
    "import torch_xla.experimental.xla_sharding as xs \n",
    "from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor\n",
    "from torch_xla.experimental.xla_sharding import Mesh\n",
    "\n",
    "# from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP, checkpoint_module\n",
    "\n",
    "import torch_xla.distributed.xla_multiprocessing as xmp\n",
    "import torch_xla.distributed.parallel_loader as pl\n",
    "import torch_xla.test.test_utils as test_utils\n",
    "\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator\n",
    "from datasets import Dataset, load_dataset, concatenate_datasets\n",
    "from peft import LoraConfig, TaskType, get_peft_model\n",
    "\n",
    "import logging\n",
    "logging.getLogger(\"datasets\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"transformers\").setLevel(logging.WARNING)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert xr.is_spmd()==True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import importlib\n",
    "sys.path.append('')\n",
    "model_partitioning = importlib.import_module('trainer_lib.model_partitioning')\n",
    "importlib.reload(model_partitioning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import torch_xla.core.xla_model as xm\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "from safetensors.torch import load_file"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook can be used to train any of the 70B, 80B models!\n",
    "supported_models = [\n",
    "    \"meta-llama/Meta-Llama-3-70B\",\n",
    "    \"meta-llama/Meta-Llama-3-70B-Instruct\",\n",
    "]\n",
    "\n",
    "# Select a supported model from above list to use!\n",
    "MODEL_NAME = \"meta-llama/Meta-Llama-3-70B\"\n",
    "HUGGINGFACE_TOKEN = input(\"**INPUT** you huggingface token: \") # YOUR_HF_TOKEN\n",
    "DEBUG_MODE = False\n",
    "\n",
    "TRAINER_CONFIG = {\n",
    "    \"epochs\": 1,\n",
    "    \"batch_size\": 1,\n",
    "    \"max_length\": 512,\n",
    "    \n",
    "    \"lr\": 5e-5,\n",
    "    \"logging_interval\": 5,  # logs every 5 steps\n",
    "    \n",
    "    \"lora_rank\": 8,\n",
    "    \"lora_alpha\": 32,\n",
    "    \"lora_dropout\": 0.1,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "from huggingface_hub import login\n",
    "login(token=HUGGINGFACE_TOKEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def apply_lora(*, model, lora_rank=None, lora_alpha=None, lora_dropout=None):\n",
    "    \"\"\"Applies LoRA configuration to the model.\"\"\"\n",
    "    peft_config = LoraConfig(\n",
    "        task_type=TaskType.CAUSAL_LM,\n",
    "        inference_mode=False,\n",
    "        r=8 if not lora_rank else lora_rank,\n",
    "        lora_alpha=32 if not lora_alpha else lora_alpha,\n",
    "        lora_dropout=0.1 if not lora_dropout else lora_dropout,\n",
    "    )\n",
    "    model = get_peft_model(model, peft_config)\n",
    "    model.print_trainable_parameters()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def init_model(*, model_name, hugging_face_token):\n",
    "    \"\"\"Downloads and initializes the model.\"\"\"\n",
    "    from huggingface_hub import hf_hub_download\n",
    "    \n",
    "    config = AutoConfig.from_pretrained(\n",
    "        model_name,\n",
    "        token=hugging_face_token)\n",
    "    with torch.device('meta'):\n",
    "        model = AutoModelForCausalLM.from_config(config)\n",
    "\n",
    "    # Load and merge state dicts.\n",
    "    from safetensors.torch import load_file\n",
    "    num_safetensor_shards = 30\n",
    "    state_dict = {}\n",
    "    for i in range(1, num_safetensor_shards + 1):\n",
    "        file = f\"model-{i:05d}-of-{num_safetensor_shards:05d}.safetensors\"\n",
    "        print(f\"downloading file {file} from HF...\")\n",
    "        file_path = hf_hub_download(repo_id=model_name, filename=file)\n",
    "        state_dict.update(load_file(file_path))\n",
    "\n",
    "    for name, param in model.named_parameters():\n",
    "        if name in state_dict:\n",
    "            param.data.copy_(state_dict[name].data)\n",
    "\n",
    "    model.to_empty(device='cpu')\n",
    "    \n",
    "    model.load_state_dict(state_dict, strict=False, assign=True)\n",
    "\n",
    "    model = apply_lora(\n",
    "        model=model,\n",
    "        lora_rank=TRAINER_CONFIG[\"lora_rank\"],\n",
    "        lora_alpha=TRAINER_CONFIG[\"lora_alpha\"],\n",
    "        lora_dropout=TRAINER_CONFIG[\"lora_dropout\"],\n",
    "    )\n",
    "    \n",
    "    tokenizer = AutoTokenizer.from_pretrained(\n",
    "        model_name, \n",
    "        token=hugging_face_token\n",
    "    )\n",
    "\n",
    "    if not tokenizer.pad_token:\n",
    "        tokenizer.pad_token = tokenizer.eos_token\n",
    "        config.pad_token_id = tokenizer.pad_token_id\n",
    "        \n",
    "    return model, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 0.547357,
     "end_time": "2023-11-04T12:34:26.034497",
     "exception": false,
     "start_time": "2023-11-04T12:34:25.48714",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def apply_spmd(*, model, mesh):\n",
    "    # Apply on layers within model.\n",
    "    model_partitioning_util.partition_model(model, mesh)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Configure dataset pipeline for your model\n",
    "\n",
    "For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.\n",
    "\n",
    "It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(*, tokenizer, batch_size=None, max_length=None, debug_mode=False):\n",
    "    # Define Alpaca prompt template\n",
    "    alpaca_prompt = \"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
    "    \n",
    "    ### Instruction: {}\n",
    "    \n",
    "    ### Input: {}\n",
    "    \n",
    "    ### Response: {}\"\"\"\n",
    "    \n",
    "    EOS_TOKEN = tokenizer.eos_token\n",
    "    \n",
    "    # Define formatting function.\n",
    "    def _format_prompts(examples):\n",
    "        instructions = examples[\"instruction\"]\n",
    "        inputs = examples[\"input\"]\n",
    "        outputs = examples[\"output\"]\n",
    "        texts = []\n",
    "        for instruction, input, output in zip(instructions, inputs, outputs):\n",
    "            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN\n",
    "            texts.append(text)\n",
    "        return {\"text\": texts}\n",
    "\n",
    "    # Tokenize the dataset.\n",
    "    def _tokenize(examples):\n",
    "        # Tokenized is list within list. Compute labels for causalLM by shifting input_id; \n",
    "        # consequently truncate input_id to penultimate position.\n",
    "        tokenized = tokenizer(examples[\"text\"], truncation=True, padding=\"max_length\", max_length=512+1 if not max_length else max_length+1)\n",
    "        labels = tokenized['input_ids'].copy()\n",
    "        tokenized['labels'] = [label[1:] for label in labels]\n",
    "        tokenized['input_ids'] = [input_id[:-1] for input_id in tokenized['input_ids']]\n",
    "        return tokenized\n",
    "\n",
    "    # Load and preprocess the dataset.\n",
    "    dataset = load_dataset(\"yahma/alpaca-cleaned\", split=\"train\")\n",
    "    if debug_mode:\n",
    "        dataset = dataset.select(range(32)) # Use just 32 exampfor faster iteration\n",
    "    dataset = dataset.map(_format_prompts, batched=True)\n",
    "\n",
    "    # Create train and test dataset.\n",
    "    ds = dataset.train_test_split(test_size=0.15)\n",
    "    ds['train'] = ds['train'].map(_tokenize, batched=True, remove_columns=dataset.column_names)\n",
    "    ds['test'] = ds['test'].map(_tokenize, batched=True, remove_columns=dataset.column_names)\n",
    "\n",
    "    # Create DataLoader\n",
    "    train_dataloader = torch.utils.data.DataLoader(\n",
    "        ds['train'],\n",
    "        shuffle=True,\n",
    "        batch_size=1 if not batch_size else batch_size,\n",
    "        collate_fn=default_data_collator,\n",
    "    )\n",
    "    \n",
    "    test_dataloader = torch.utils.data.DataLoader(\n",
    "        ds['test'],\n",
    "        shuffle=True,\n",
    "        batch_size=1 if not batch_size else batch_size,\n",
    "        collate_fn=default_data_collator,\n",
    "    )\n",
    "\n",
    "    return train_dataloader, test_dataloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model\n",
    "\n",
    "Now let's train the model. We are using PyTorch XLA's Fully Sharded Data Parallel (FSDP) to distribute the model across the 8 TPU cores available on TPU v3-8. This approach allows for efficient training on TPU hardware. We also utilize PyTorch/XLA's MpDeviceLoader to efficiently load data onto the TPU cores.\n",
    "\n",
    "**NOTE:** It's important to note that the **first step of training will be slow**. This is because XLA takes time initially to compile the computational graph. However, once the compilation is complete, subsequent steps will run much faster using compiled+cached graph, and leveraging the full power of the all TPU cores for accelerated training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_training_update(step,\n",
    "                          loss,\n",
    "                          epoch=None):\n",
    "    \"\"\"Prints the training metrics at a given step.\"\"\"\n",
    "    if xm.is_master_ordinal():  # Only print on the master device\n",
    "        update_data = [\n",
    "            'Training',\n",
    "            f'Epoch={epoch}' if epoch is not None else None,\n",
    "            f'Step={step}',\n",
    "            f'Loss={loss:.5f}',\n",
    "        ]\n",
    "        print(' | '.join(item for item in update_data if item), flush=True)\n",
    "        print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, tokenizer = init_model(\n",
    "        model_name=MODEL_NAME, hugging_face_token=HUGGINGFACE_TOKEN\n",
    ")\n",
    "model = xmp.MpModelWrapper(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(index):\n",
    "    global model, tokenizer\n",
    "\n",
    "    torch.manual_seed(99)\n",
    "    device = xm.xla_device()\n",
    "    model = model.to(device)\n",
    "\n",
    "    # Create a mesh for the model partitioning.\n",
    "    num_devices = xr.global_runtime_device_count()\n",
    "    mesh_shape = (1, num_devices, 1)\n",
    "    device_ids = np.array(range(num_devices))\n",
    "    mesh = Mesh(device_ids, mesh_shape, (\"dp\", \"fsdp\", \"mp\"))\n",
    "    \n",
    "    # Partition the model using SPMD.\n",
    "    model_partitioning.partition_model(model=model, mesh=mesh)\n",
    "    \n",
    "    # Configure the training loop.\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=TRAINER_CONFIG[\"lr\"])\n",
    "\n",
    "    train_dataloader, test_dataloader = get_dataset(\n",
    "        tokenizer=tokenizer,\n",
    "        batch_size=TRAINER_CONFIG[\"batch_size\"],\n",
    "        max_length=TRAINER_CONFIG[\"max_length\"],\n",
    "    )\n",
    "    train_dataloader = pl.MpDeviceLoader(\n",
    "        train_dataloader, \n",
    "        device\n",
    "    ) \n",
    "    test_dataloader = pl.MpDeviceLoader(\n",
    "        test_dataloader, \n",
    "        device\n",
    "    )\n",
    "\n",
    "    for epoch in range(TRAINER_CONFIG[\"epochs\"]):\n",
    "        xm.master_print(f\"Epoch {epoch} train begin {test_utils.now()}\")\n",
    "        tracker = xm.RateTracker()\n",
    "        \n",
    "        model.train()\n",
    "        for step, batch in enumerate(train_dataloader):\n",
    "            if step>1:\n",
    "                break\n",
    "                \n",
    "            optimizer.zero_grad()\n",
    "            \n",
    "            input_ids, attention_mask, labels = (\n",
    "                batch[\"input_ids\"],\n",
    "                batch[\"attention_mask\"],\n",
    "                batch[\"labels\"],\n",
    "            )\n",
    "            xs.mark_sharding(input_ids, mesh, (0, 1))\n",
    "            xs.mark_sharding(attention_mask, mesh, (0, 1))\n",
    "            xs.mark_sharding(labels, mesh, (0, 1))\n",
    "            \n",
    "            output = model(\n",
    "                input_ids=input_ids, attention_mask=attention_mask, labels=labels\n",
    "            )\n",
    "            loss = output.loss\n",
    "            loss.backward()\n",
    "            \n",
    "            optimizer.step()\n",
    "            xm.mark_step()\n",
    "\n",
    "            if step%TRAINER_CONFIG[\"logging_interval\"]==0:\n",
    "                loss_cpu = loss.item()\n",
    "                xm.add_step_closure(\n",
    "                    print_training_update,\n",
    "                    args=(step, loss_cpu, epoch)\n",
    "                )\n",
    "\n",
    "        # UNCOMMENT BELOW TO RUN EVAL.\n",
    "        # model.eval()\n",
    "        # eval_loss = 0\n",
    "        # with torch.no_grad():\n",
    "        #     for step, batch in enumerate(test_dataloader):\n",
    "        #         input_ids, attention_mask, labels = (\n",
    "        #             batch[\"input_ids\"],\n",
    "        #             batch[\"attention_mask\"],\n",
    "        #             batch[\"labels\"],\n",
    "        #         )\n",
    "        #         xs.mark_sharding(input_ids, mesh, (0, 1))\n",
    "        #         xs.mark_sharding(attention_mask, mesh, (0, 1))\n",
    "        #         xs.mark_sharding(labels, mesh, (0, 1))\n",
    "                \n",
    "        #         output = model(\n",
    "        #             input_ids=input_ids, attention_mask=attention_mask, labels=labels\n",
    "        #         )\n",
    "        #         eval_loss += output.loss.item()\n",
    "        # avg_eval_loss = eval_loss / len(test_dataloader)\n",
    "        # xm.add_step_closure(\n",
    "        #     lambda: print(f\"Eval loss: {avg_eval_loss:.4f}\"),\n",
    "        # )\n",
    "    result = {'device': xm.get_ordinal(), 'loss': loss.item()}\n",
    "    print(f\"Finished training!\")\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    xmp.spawn(train, args=(), start_method=\"fork\")\n",
    "except Exception as e:\n",
    "    # Catch the expected error of obtaining results from multiple TPU chips when starting distributed training from a notebook.\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Export the model to HuggingFace Hub\n",
    "Uncoment the following cell to push the model to HuggingFace Hub."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from huggingface_hub import login\n",
    "\n",
    "login(token=HUGGINGFACE_TOKEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "papermill": {
     "duration": 388.410448,
     "end_time": "2023-11-04T13:21:54.038795",
     "exception": false,
     "start_time": "2023-11-04T13:15:25.628347",
     "status": "completed"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "HUGGINGFACE_USERNAME = input(\"Please provide your HUGGINGFACE_USERNAME: \")\n",
    "LOCAL_SAVE_DIR = \"/mnt/persistent-disk/felafax-finetuned-model/\"  \n",
    "\n",
    "model = model.to(device='cpu')\n",
    "merged_model = model.merge_and_unload()\n",
    "\n",
    "print(\"Saving model locally...\")\n",
    "merged_model.save_pretrained(\n",
    "    LOCAL_SAVE_DIR,\n",
    "    max_shard_size=\"5GB\",\n",
    "    safe_serialization=True\n",
    ")\n",
    "\n",
    "\n",
    "print(\"Uncomment below code if you want to upload to HF.\")\n",
    "# print(\"Uploading to HF...\")\n",
    "# from huggingface_hub import HfApi\n",
    "# api = HfApi()\n",
    "# api.upload_folder(\n",
    "#     folder_path=LOCAL_SAVE_DIR,\n",
    "#     repo_id=f\"{HUGGINGFACE_USERNAME}/felafax-llama3-finetuned-70B\",\n",
    "#     repo_type=\"model\",\n",
    "#     ignore_patterns=[\".*\"],\n",
    "#     token=HUGGINGFACE_TOKEN\n",
    "# )"
   ]
  }
 ],
 "metadata": {
  "kaggle": {
   "accelerator": "tpu1vmV38",
   "dataSources": [
    {
     "databundleVersionId": 7516023,
     "sourceId": 61542,
     "sourceType": "competition"
    },
    {
     "datasetId": 3555678,
     "isSourceIdPinned": true,
     "sourceId": 6196932,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3863727,
     "sourceId": 6703755,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3936750,
     "sourceId": 6847931,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3946973,
     "sourceId": 6867914,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3937441,
     "sourceId": 6868189,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3949797,
     "sourceId": 6873567,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3942644,
     "sourceId": 6890527,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3937250,
     "sourceId": 7017419,
     "sourceType": "datasetVersion"
    },
    {
     "datasetId": 3944051,
     "sourceId": 7060310,
     "sourceType": "datasetVersion"
    }
   ],
   "dockerImageVersionId": 30529,
   "isGpuEnabled": false,
   "isInternetEnabled": true,
   "language": "python",
   "sourceType": "notebook"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
